Source code for nlp_architect.models.cross_doc_coref.system.sieves.sieves

# ******************************************************************************
# Copyright 2017-2018 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
import logging
from typing import Tuple

from nlp_architect.common.cdc.cluster import Cluster
from nlp_architect.data.cdc_resources.relations.relation_extraction import RelationExtraction
from nlp_architect.data.cdc_resources.relations.relation_types_enums import RelationType

logger = logging.getLogger(__name__)


class SieveClusterMerger(object):
    """Decide whether two clusters should be merged, based on one relation type.

    A sieve is configured with a single relation type and a threshold. Two
    clusters are merged when the fraction of cross-cluster mention pairs for
    which the relation extractor reports that relation reaches the threshold.
    """

    def __init__(self, excepted_relation: Tuple[RelationType, float],
                 relation_extractor: RelationExtraction):
        """
        Args:
            excepted_relation: tuple of (relation type this sieve checks,
                minimal fraction of matching mention pairs required to merge
                two clusters). The first element must expose a ``name``
                attribute, which is used for logging.
            relation_extractor: object whose ``extract_sub_relations(mention_i,
                mention_j, relation)`` is called for every cross-cluster
                mention pair; a pair counts as a match when the call returns
                a value equal to the sieve's relation type.
        """
        self.excepted_relation = excepted_relation[0]
        self.threshold = excepted_relation[1]
        self.relation_extractor = relation_extractor
        logger.info('init Sieve, for relation-%s with threshold=%.1f',
                    self.excepted_relation.name, self.threshold)

    def run_sieve(self, cluster_i: Cluster, cluster_j: Cluster) -> bool:
        """Check whether ``cluster_i`` and ``cluster_j`` should be merged.

        Every mention of ``cluster_i`` is compared against every mention of
        ``cluster_j`` with the relation extractor; the merge decision is based
        on the fraction of pairs that match the sieve's relation.

        Args:
            cluster_i: first cluster; its ``mentions`` are compared pairwise.
            cluster_j: second cluster; its ``mentions`` are compared pairwise.

        Returns:
            bool -> True when (matching pairs / all pairs) >= threshold,
            i.e. the clusters should be merged; False otherwise, including
            when either cluster has no mentions.
        """
        possible_pairs_len = len(cluster_i.mentions) * len(cluster_j.mentions)
        if possible_pairs_len == 0:
            # No mention pairs to compare: the match rate is undefined (and
            # dividing would raise ZeroDivisionError), so there is no evidence
            # supporting a merge.
            return False

        matches = 0
        for mention_i in cluster_i.mentions:
            for mention_j in cluster_j.mentions:
                match_result = self.relation_extractor.extract_sub_relations(
                    mention_i, mention_j, self.excepted_relation)
                # presumably the extractor returns the relation it found (or a
                # "no relation" value) — only an exact match to this sieve's
                # relation is counted.
                if match_result == self.excepted_relation:
                    matches += 1

        matches_rate = matches / float(possible_pairs_len)
        return matches_rate >= self.threshold